/* Copyright (c) 2025, Canaan Bright Sight Co., Ltd
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 1. Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
#include <asm/asm.h>
#include <asm/io.h>
#include <asm/spl.h>
#include <asm/types.h>
#include <command.h>
#include <common.h>
#include <cpu_func.h>
#include <dm/device-internal.h>
#include <gzip.h>
#include <image.h>
#include <linux/delay.h>
#include <linux/kernel.h>
#include <linux/mtd/mtd.h>
#include <lmb.h>
#include <mmc.h>
#include <nand.h>
#include <spi.h>
#include <spi_flash.h>
#include <spl.h>
#include <stdint.h>
#include <stdio.h>

#include "board_common.h"

#include "k230_atag.h"

#include "kendryte/pufs/pufs_hmac/pufs_hmac.h"

/* TOC (Table of Contents) 定义 */
#define K230_TOC_OFFSET        0xe0000
#define K230_TOC_SECTOR        (K230_TOC_OFFSET / 512)
#define K230_TOC_MAX_ENTRIES   10
#define K230_TOC_ENTRY_SIZE    64

#define K230_BOOT_CORE_NUM_SHIFT    (0x1)
#define K230_BOOT_CORE_NUM_MASK     (0x3 << K230_BOOT_CORE_NUM_SHIFT) 

struct k230_toc_entry {
    char name[32];
    uint64_t offset;
    uint64_t size;
    uint8_t load;
    uint8_t boot;
    uint8_t pading_1[2];
    uint32_t load_addr;
    uint8_t pading_2[4];
} __attribute__((aligned(64)));

struct k230_toc {
	uint32_t entry_count;
	struct k230_toc_entry entries[K230_TOC_MAX_ENTRIES];
};

static struct k230_toc toc;
static struct blk_desc *pblk_desc;
static uint64_t rtapp_load_addr, rtapp_size, rttapp_loaded = 0;

unsigned long k230_get_encrypted_image_load_addr(void)
{
    unsigned long addr = g_dram_base + g_dram_size - ((g_dram_size / 3) * 2);
    return addr & ~(4096-1);
}

unsigned long k230_get_rttapp_load_addr(void)
{
    unsigned long addr = g_dram_base + g_dram_size - (g_dram_size / 3);
    return addr & ~(4096-1);
}

static int k230_boot_decomp_to_load_addr(image_header_t* pUh, ulong des_len, ulong data, ulong* plen)
{
    int   ret               = 0;
    ulong img_load_addr     = (ulong)image_get_load(pUh);
    int   img_compress_algo = image_get_comp(pUh);

    // printf("image: %s load to 0x%lx, compress=%d src=0x%lx len=0x%lx\n", image_get_name(pUh), img_load_addr,
    //        img_compress_algo, data, *plen);

    if (IH_COMP_GZIP == img_compress_algo) {
        if (0x00 != (ret = gunzip((void*)img_load_addr, des_len, (void*)data, plen))) {
            printf("unzip fialed ret =%x\n", ret);
            return -1;
        }
    } else if (IH_COMP_NONE == img_compress_algo) {
        memmove((void*)img_load_addr, (void*)data, *plen);
    } else {
        printf("Error: Unsupport compress algo.\n");
        return -2;
    }

    flush_cache(img_load_addr, *plen);

    return ret;
}

static int k230_boot_check_and_get_plain_data(firmware_head_s* pfh, ulong* pplain_addr)
{
    pufs_dgst_st md;

    if (K230_IMAGE_MAGIC_NUM != pfh->magic) {
        printf("magic error 0x%08X != 0x%08X \n", K230_IMAGE_MAGIC_NUM, pfh->magic);
        return 1;
    }

    if (NONE_SECURITY == pfh->crypto_type) {
        if(SUCCESS != cb_pufs_hash(&md, (const uint8_t*)(pfh + 1), pfh->length, SHA_256)) {
            printf("sha256 error\n");
            return 2;
        }

        if (memcmp(md.dgst, pfh->verify.none_sec.signature, SHA256_SUM_LEN)) {
            printf("sha256 error");
            return 2;
        }

        if (pplain_addr) {
            *pplain_addr = (ulong)pfh + sizeof(*pfh);
        }
        return 0;
    } else if ((CHINESE_SECURITY == pfh->crypto_type) || (INTERNATIONAL_SECURITY == pfh->crypto_type)
               || (GCM_ONLY == pfh->crypto_type)) {
        printf("error, not support encrypted firmware\n");
        return 3;
    } else {
        printf("error crypto type =0x%x\n", pfh->crypto_type);
        return 4;
    }

    // never reach
    return -1;
}


#if defined(CONFIG_MTD_SPI_NAND)

#define SPINAND_NAME "spi-nand0"

__weak ulong get_nand_start_by_boot_firmre_type(en_boot_sys_t sys)
{
    ulong blk_s = IMG_PART_NOT_EXIT;
    switch (sys) {
    case BOOT_SYS_RTT:
        blk_s = RTT_SYS_IN_SPI_NAND_OFF;
        break;
    case BOOT_SYS_UBOOT:
        blk_s = UBOOT_SYS_IN_SPI_NAND_OFF;
        break;
    default:
        break;
    }
    return blk_s;
}

static struct mtd_info* get_mtd_by_name(const char* name)
{
    struct mtd_info* mtd;

    mtd_probe_devices();

    mtd = get_mtd_device_nm(name);
    if (IS_ERR_OR_NULL(mtd)) {
        printf("MTD device %s not found, ret %ld\n", name, PTR_ERR(mtd));
    }

    return mtd;
}

static inline bool mtd_is_aligned_with_block_size(struct mtd_info* mtd, u64 size)
{
    return !do_div(size, mtd->erasesize);
}

#endif

#if defined(CONFIG_MMC)
static int k230_mmc_init(void)
{
    int ret = 0;

    if (!pblk_desc) {
        struct mmc *mmc = NULL;

        if (mmc_init_device(g_boot_medium - BOOT_MEDIUM_SDIO0)) {
            ret = -1;
            goto out;
        }

        mmc = find_mmc_device(g_boot_medium - BOOT_MEDIUM_SDIO0);

        if (NULL == mmc) {
            ret = -2;
            goto out;
        }

        if (mmc_init(mmc)) {
            ret = -3;
            goto out;
        }

        pblk_desc = mmc_get_blk_desc(mmc);
        if (NULL == pblk_desc) {
            ret = -4;
            goto out;
        }
    }

out:
    return ret;
}
#endif

static void *k230_read_toc(void)
{
    int ret = 0;
    ulong blk_count;
    void *toc_buf;

    memset(&toc, 0, sizeof(toc));

    if (K230_TOC_ENTRY_SIZE != sizeof(struct k230_toc_entry)) {
        printf("%s: struct k230_toc_entry must aligned to 64B\n", __func__);
        ret = -1;
        goto out;
    }
    toc_buf = (void *)k230_get_encrypted_image_load_addr();

#if defined(CONFIG_MMC)
    if ((BOOT_MEDIUM_SDIO0 == g_boot_medium) || (BOOT_MEDIUM_SDIO1 == g_boot_medium)) {
        ret = k230_mmc_init();
        if (0 != ret) {
            printf("%s: k230_mmc_init fail: %d\n", __func__, ret);
            goto out;
        }

        blk_count = (K230_TOC_MAX_ENTRIES * K230_TOC_ENTRY_SIZE + 511) / 512;

        if (blk_dread(pblk_desc, K230_TOC_SECTOR, blk_count, toc_buf) != blk_count) {
            printf("%s: Error: Failed to read TOC from MMC\n", __func__);
            ret = -1;
            goto out;
        }
    } else
#endif

#if defined(CONFIG_MTD_SPI_NAND)
    if (g_boot_medium == BOOT_MEDIUM_NANDFLASH) {
        static struct mtd_info *mtd;
        struct mtd_oob_ops io_op = {};
        size_t len = K230_TOC_MAX_ENTRIES * K230_TOC_ENTRY_SIZE;
        size_t blocksize = 0;
        size_t amount_loaded = 0;
        size_t end = 0;
        ulong off = 0;
        u_char* buf = (u_char*)toc_buf;

        mtd = get_mtd_by_name(SPINAND_NAME);
        if (IS_ERR_OR_NULL(mtd)) {
            printf("k230_load_sys_from_spi_nand error\n");
            ret = -1;
            goto out;

        }
        blocksize = mtd->erasesize;
        off = K230_TOC_OFFSET;
        end = off + len;

        io_op.mode   = MTD_OPS_AUTO_OOB;
        io_op.len    = mtd->writesize;
        io_op.ooblen = 0;
        io_op.oobbuf = NULL;

        amount_loaded = 0;
        while (off < end) {
            if (mtd_is_aligned_with_block_size(mtd, off) && mtd_block_isbad(mtd, off)) {
                off += blocksize;
            } else {
                io_op.datbuf = &buf[amount_loaded];
                if (mtd_read_oob(mtd, off, &io_op)) {
                    printf("%s: read firmware head error\n", __func__);
                    ret = -1;
                    put_mtd_device(mtd);
                    goto out;
                }

                off += io_op.retlen;
                amount_loaded += io_op.retlen;
            }
        }
    } else
#endif
    {
        ret = -1;
        printf("%s: Error, unsupport media type %d\n", __func__, g_boot_medium);
    }

out:
    if (0 != ret) {
        return NULL;
    }

    return toc_buf;
}

#if 0
static void dump_toc_info(void)
{
    printf("\n");
    printf("=============== K230 TOC (Table of Contents) ===============\n");
    printf("TOC Location: 0x%08x (sector %u)\n", K230_TOC_OFFSET, K230_TOC_SECTOR);
    printf("Total Entries: %u\n", toc.entry_count);
    printf("------------------------------------------------------------\n");
    printf("%-16s %-12s %-12s %-6s %-6s\n",
           "Name", "Offset", "Size", "Load", "Boot");
    printf("------------------------------------------------------------\n");

    for (int i = 0; i < toc.entry_count; i++) {
        struct k230_toc_entry *entry = &toc.entries[i];

        printf("%-16s  0x%08lx    0x%08lx %-6s   0x%x\n",
               entry->name,
               entry->offset,
               entry->size,
               entry->load ? "YES" : "NO",
               entry->boot);
    }

    printf("============================================================\n");
    printf("\n");
}
#endif

static int k230_parse_toc(void *toc_buf)
{
    struct k230_toc_entry *src = (u8 *)toc_buf;

	toc.entry_count = 0;
	for (int i = 0; i < K230_TOC_MAX_ENTRIES; i++) {
        struct k230_toc_entry *dst = &toc.entries[toc.entry_count];

		if (src->name[0] == '\0') {
			break;
		}

		memcpy(dst, src, sizeof(struct k230_toc_entry));

		toc.entry_count++;
        src ++;
	}

    //dump_toc_info();

	return (toc.entry_count > 0) ? 0 : -7;
}

static uint get_gunzip_dest_length(const u8 *zipped_data, uint zipped_len)
{
    if (zipped_data == NULL || zipped_len < 4) {
        return 0;
    }
    
    uint original_len;
    memcpy(&original_len, zipped_data + zipped_len - 4, 4);
    
    return original_len;
}

static uint _k230_load_img(uint64_t offset)
{
    int ret = 0;
    ulong src_len, src_data, dst_len, plain_addr = 0;
    image_header_t *pUh = NULL;
    ulong buff = k230_get_encrypted_image_load_addr();
    firmware_head_s *pfh = (firmware_head_s*)buff;

#if defined(CONFIG_MMC)
    if ((BOOT_MEDIUM_SDIO0 == g_boot_medium) || (BOOT_MEDIUM_SDIO1 == g_boot_medium)) {
        ulong blk_s = offset / BLKSZ;
        ulong data_sect = 0;

        if (IMG_PART_NOT_EXIT == blk_s) {
            printf("%s: invalid blk num\n", __func__);
            ret = -1;
            goto out;
        }

        if (NULL == pblk_desc) {
            printf("%s: mmc is not init\n", __func__);
            ret = -1;
            goto out;
        }

        ret = blk_dread(pblk_desc, blk_s, HD_BLK_NUM, (char*)buff);
        if (ret != HD_BLK_NUM) {
            printf("%s: blk_dread fail: %d\n", __func__, ret);
            ret = -1;
            goto out;
        }

        if (pfh->magic != K230_IMAGE_MAGIC_NUM) {
            printf("%s: pfh->magic 0x%x != 0x%x blk=0x%lx buff=0x%lx\n",
                   __func__, pfh->magic, K230_IMAGE_MAGIC_NUM, blk_s, buff);
            ret = -1;
            goto out;
        }

        data_sect = DIV_ROUND_UP(pfh->length + sizeof(*pfh), BLKSZ) - HD_BLK_NUM;
        ret = blk_dread(pblk_desc, blk_s + HD_BLK_NUM, data_sect, (char*)buff + HD_BLK_NUM * BLKSZ);
        if (ret != data_sect) {
            printf("%s: blk_dread failed: %d\n", __func__, ret);
            ret = -1;
            goto out;
        }
    } else
#endif

#if defined(CONFIG_MTD_SPI_NAND)
    if (g_boot_medium == BOOT_MEDIUM_NANDFLASH) {
        u_char*          buf           = (u_char*)buff;
        size_t           blocksize     = 0;
        size_t           amount_loaded = 0;
        size_t           end           = 0;
        ulong            off           = offset;

        static struct mtd_info* mtd;
        struct mtd_oob_ops      io_op = {};

        if (IMG_PART_NOT_EXIT == off) {
            return IMG_PART_NOT_EXIT;
        }

        mtd = get_mtd_by_name(SPINAND_NAME);
        if (IS_ERR_OR_NULL(mtd)) {
            printf("%s: get_mtd_by_name fail\n", __func__);
            ret = -1;
            goto out;
        }

        blocksize = mtd->erasesize;

        end = off + sizeof(*pfh);

        io_op.mode   = MTD_OPS_AUTO_OOB;
        io_op.len    = mtd->writesize;
        io_op.ooblen = 0;
        io_op.oobbuf = NULL;

        amount_loaded = 0;
        while (off < end) {
            if (mtd_is_aligned_with_block_size(mtd, off) && mtd_block_isbad(mtd, off)) {
                off += blocksize;
            } else {
                io_op.datbuf = &buf[amount_loaded];
                if (mtd_read_oob(mtd, off, &io_op)) {
                    printf("%s: read firmware head error\n", __func__);
                    put_mtd_device(mtd);
                    ret = -1;
                    goto out;
                }

                off += io_op.retlen;
                amount_loaded += io_op.retlen;
            }
        }

        if (K230_IMAGE_MAGIC_NUM != pfh->magic) {
            printf("%s: pfh->magic 0x%x != 0x%x off=0x%lx buff=0x%lx ",
                   __func__, pfh->magic, K230_IMAGE_MAGIC_NUM, off, buff);
            put_mtd_device(mtd);
            ret = -1;
            goto out;
        }

        if (pfh->length > (mtd->writesize - sizeof(*pfh))) {
            end = off + pfh->length - (mtd->writesize - sizeof(*pfh));
            while (off < end) {
                if (mtd_is_aligned_with_block_size(mtd, off) && mtd_block_isbad(mtd, off)) {
                    off += blocksize;
                } else {
                    io_op.datbuf = &buf[amount_loaded];
                    if (mtd_read_oob(mtd, off, &io_op)) {
                        printf("%s: read firmware error\n", __func__);
                        put_mtd_device(mtd);
                        ret = -1;
                        goto out;
                    }

                    off += io_op.retlen;
                    amount_loaded += io_op.retlen;
                }
            }
        }
    } else
#endif
    {
        ret = -1;
        printf("%s: Error, unsupport media type %d\n", __func__, g_boot_medium);
        goto out;
    }

    ret = k230_boot_check_and_get_plain_data((firmware_head_s*)buff, &plain_addr);
    if (ret) {
        printf("%s: decrypt image failed: %d\n", __func__, ret);
        ret = -1;
        goto out;
    }

    pUh = (image_header_t*)(plain_addr + 4);
    if (!image_check_magic(pUh)) {
        printf("%s: bad magic\n", __func__);
        ret = -1;
        goto out;
    }

    src_len  = image_get_size(pUh);
    src_data = image_get_data(pUh);

    if (IH_TYPE_MULTI == image_get_type(pUh)) {
        image_multi_getimg(pUh, 0, &src_data, &src_len);
    }

    dst_len = get_gunzip_dest_length((u8 *)src_data, src_len);
    if (0 == dst_len) {
        dst_len = 0x6000000;
    }

    if (strncmp(image_get_name(pUh), "rtapp", 32) == 0) {
        rtapp_size = dst_len;
        rtapp_load_addr = k230_get_rttapp_load_addr();
        image_set_load(pUh, rtapp_load_addr);

        if((rtapp_load_addr + rtapp_size) > g_dram_size) {
            printf("rtapp too large\n");
            return INVALID_LOAD_ADDR;
        }

        rttapp_loaded = 1;
    }

    ret = k230_boot_decomp_to_load_addr(pUh, dst_len, src_data, &src_len);
    if (ret) {
        printf("%s: decomp_to_load_addr fail: %d\n", __func__, ret);
        ret = -1;
    }

out:
    if (0 != ret) {
        return INVALID_LOAD_ADDR;
    }

    return image_get_load(pUh);
}

static int k230_load_img(void)
{
    uint l_addr;

    for (int i = 0; i < toc.entry_count; i++) {
        if (toc.entries[i].load) {
            l_addr = _k230_load_img(toc.entries[i].offset);
            if (l_addr == INVALID_LOAD_ADDR) {
                printf("%s: load %s fail\n", __func__, toc.entries[i].name);
                continue;
            }
            toc.entries[i].load_addr = l_addr;
        }
    }

    return 0;
}

static void k230_boot_core(int core, ulong run_addr)
{
    switch (core) {
    case 0: {
        void (*run)(ulong hart, void* dtb);

        icache_disable();
        dcache_disable();

        asm volatile(".long 0x0170000b\n" ::: "memory");
        run = (void (*)(ulong, void*))run_addr;
        run(0, 0);

        break;
    }
    case 1: {

        writel(run_addr, (void*)0x91102104ULL);
        writel(0x10001000, (void*)0x9110100cULL);
        writel(0x10001, (void*)0x9110100cULL);
        writel(0x10000, (void*)0x9110100cULL);

        break;
    }
    default:
        printf("invalid core num boot to\n");
        break;
    }

}

static void k230_setup_user_tag(void)
{
    setup_start_tag();
    setup_k230_ddr_size_tag(g_dram_size);
    if(rttapp_loaded) {
        setup_k230_rtapp_tag(rtapp_size, rtapp_load_addr);
    }
    setup_end_tag();
}

static void k230_boot_img(void)
{
    struct k230_toc_entry *core0_entry = NULL;
    struct k230_toc_entry *core1_entry = NULL;

    for (int i = 0; i < toc.entry_count; i++) {
        if (toc.entries[i].boot) {
            int core_num = (toc.entries[i].boot & K230_BOOT_CORE_NUM_MASK)
                >> K230_BOOT_CORE_NUM_SHIFT;

            if (core_num == 0) {
                core0_entry = &toc.entries[i];
                printf("Found Core 0 (Small Core): %s at 0x%lx\n",
                       toc.entries[i].name, toc.entries[i].load_addr);
            } else if (core_num == 1) {
                core1_entry = &toc.entries[i];
                printf("Found Core 1 (Big Core): %s at 0x%lx\n",
                       toc.entries[i].name, toc.entries[i].load_addr);
            }
        }
    }

    if (core1_entry) {
        k230_setup_user_tag();
        k230_boot_core(1, core1_entry->load_addr);
    }

    if (core0_entry) {
        k230_boot_core(0, core0_entry->load_addr);
    }

    while (1) {
        asm volatile("wfi");
    }
}

int k230_run_system(void)
{
    int ret = 0;
    void *toc_buf;

    toc_buf = k230_read_toc();
    if (NULL == toc_buf) {
        printf("%s: k230_read_toc fail\n", __func__);
        ret = -1;
        goto out;
    }

    ret = k230_parse_toc(toc_buf);
    if (0 != ret) {
        printf("%s: not toc exit: %d\n", __func__, ret);
        goto out;
    }

    ret = k230_load_img();
    if (0 != ret) {
        printf("%s: k230_load_img fail: %d\n", __func__, ret);
        goto out;
    }

    k230_boot_img();

out:
    return ret;
}

